package config

import (
	"fmt"
	"os"
	"strings"
	"testing"

	"github.com/oauth2-proxy/mockoidc"
	"github.com/stretchr/testify/assert"
)

// TODO: add more tests, e.g. for YAML and environment variable parsing, validation, etc.

// Test_Load_OidcProviders verifies that OIDC providers declared through indexed
// WAKAPI_OIDC_PROVIDERS_<i>_* environment variables are picked up by Load and
// can be looked up by name through GetOidcProvider. Two mock OIDC servers
// supply the client credentials and endpoints. The second provider sets no
// display name, so the test also pins the value derived from its name.
func Test_Load_OidcProviders(t *testing.T) {
	oidcMock1, err := mockoidc.Run()
	if err != nil {
		t.Fatalf("starting first mock oidc provider: %v", err)
	}
	defer oidcMock1.Shutdown()
	oidcMock2, err := mockoidc.Run()
	if err != nil {
		t.Fatalf("starting second mock oidc provider: %v", err)
	}
	defer oidcMock2.Shutdown()

	// setEnv sets an environment variable for the duration of this test only.
	// On cleanup it restores the previous value, or unsets the variable if it
	// was absent, so these provider definitions don't leak into other tests
	// of this package.
	setEnv := func(key, value string) {
		prev, existed := os.LookupEnv(key)
		if err := os.Setenv(key, value); err != nil {
			t.Fatalf("setting %s: %v", key, err)
		}
		t.Cleanup(func() {
			if existed {
				_ = os.Setenv(key, prev) // best-effort restore
			} else {
				_ = os.Unsetenv(key) // best-effort restore
			}
		})
	}

	setEnv("WAKAPI_OIDC_PROVIDERS_0_NAME", "testprovider1")
	setEnv("WAKAPI_OIDC_PROVIDERS_0_DISPLAY_NAME", "Test Provider 1")
	setEnv("WAKAPI_OIDC_PROVIDERS_0_CLIENT_ID", oidcMock1.ClientID)
	setEnv("WAKAPI_OIDC_PROVIDERS_0_CLIENT_SECRET", oidcMock1.ClientSecret)
	setEnv("WAKAPI_OIDC_PROVIDERS_0_ENDPOINT", oidcMock1.Addr()+"/oidc")
	setEnv("WAKAPI_OIDC_PROVIDERS_1_NAME", "testprovider2")
	setEnv("WAKAPI_OIDC_PROVIDERS_1_CLIENT_ID", oidcMock2.ClientID)
	setEnv("WAKAPI_OIDC_PROVIDERS_1_CLIENT_SECRET", oidcMock2.ClientSecret)
	setEnv("WAKAPI_OIDC_PROVIDERS_1_ENDPOINT", oidcMock2.Addr()+"/oidc")

	cfg := Load("", "")
	oidcCfg := cfg.Security.OidcProviders

	// Stop here on a wrong count. Otherwise the index accesses below would
	// panic instead of reporting a readable failure.
	if !assert.Len(t, oidcCfg, 2) {
		t.FailNow()
	}
	assert.Equal(t, "testprovider1", oidcCfg[0].Name)
	assert.Equal(t, "Test Provider 1", oidcCfg[0].DisplayName)
	assert.Equal(t, "Test Provider 1", oidcCfg[0].String())
	assert.Equal(t, oidcMock1.ClientID, oidcCfg[0].ClientID)
	assert.Equal(t, oidcMock1.ClientSecret, oidcCfg[0].ClientSecret)
	assert.Equal(t, oidcMock1.Addr()+"/oidc", oidcCfg[0].Endpoint)
	assert.Equal(t, "testprovider2", oidcCfg[1].Name)
	assert.Equal(t, "", oidcCfg[1].DisplayName)
	assert.Equal(t, "Testprovider2", oidcCfg[1].String())
	assert.Equal(t, oidcMock2.ClientID, oidcCfg[1].ClientID)
	assert.Equal(t, oidcMock2.ClientSecret, oidcCfg[1].ClientSecret)
	assert.Equal(t, oidcMock2.Addr()+"/oidc", oidcCfg[1].Endpoint)

	// Only inspect the returned providers when the lookup succeeded, so a
	// failed lookup doesn't turn into a nil dereference.
	p1, err1 := GetOidcProvider("testprovider1")
	if assert.NoError(t, err1) {
		assert.Equal(t, "Test Provider 1", p1.DisplayName)
	}

	p2, err2 := GetOidcProvider("testprovider2")
	if assert.NoError(t, err2) {
		assert.Equal(t, "Testprovider2", p2.DisplayName)
	}
}

// TestOidcProviderConfig_Validate runs oidcProviderConfig.Validate against a
// table of configurations. An empty wantErr means the configuration must be
// accepted. Otherwise Validate must fail with exactly that message.
//
// Note: the initial set of cases was AI-generated.
// NOTE(review): the two malformed-endpoint cases (bad scheme, no scheme)
// expect the same "missing endpoint" message as an absent endpoint. Confirm
// this matches the intended Validate contract.
func TestOidcProviderConfig_Validate(t *testing.T) {
	cases := []struct {
		name    string
		config  oidcProviderConfig
		wantErr string
	}{
		{
			name: "valid",
			config: oidcProviderConfig{
				Name:         "test-provider-1",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
				Endpoint:     "https://provider.com/oidc",
			},
		},
		{
			name: "valid with http",
			config: oidcProviderConfig{
				Name:         "test-provider-1",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
				Endpoint:     "http://provider.com/oidc",
			},
		},
		{
			name:    "invalid name with spaces",
			config:  oidcProviderConfig{Name: "test provider"},
			wantErr: "invalid provider name 'test provider', must only contain alphanumeric characters or '-'",
		},
		{
			name:    "invalid name with underscore",
			config:  oidcProviderConfig{Name: "test_provider"},
			wantErr: "invalid provider name 'test_provider', must only contain alphanumeric characters or '-'",
		},
		{
			name: "missing client id",
			config: oidcProviderConfig{
				Name:         "test-provider",
				ClientSecret: "client-secret",
				Endpoint:     "https://provider.com/oidc",
			},
			wantErr: "provider 'test-provider' is missing client id",
		},
		{
			name: "missing client secret",
			config: oidcProviderConfig{
				Name:     "test-provider",
				ClientID: "client-id",
				Endpoint: "https://provider.com/oidc",
			},
			wantErr: "provider 'test-provider' is missing client secret",
		},
		{
			name: "missing endpoint",
			config: oidcProviderConfig{
				Name:         "test-provider",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
			},
			wantErr: "provider 'test-provider' is missing endpoint",
		},
		{
			name: "invalid endpoint scheme",
			config: oidcProviderConfig{
				Name:         "test-provider",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
				Endpoint:     "ftp://provider.com/oidc",
			},
			wantErr: "provider 'test-provider' is missing endpoint",
		},
		{
			name: "endpoint without scheme",
			config: oidcProviderConfig{
				Name:         "test-provider",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
				Endpoint:     "provider.com/oidc",
			},
			wantErr: "provider 'test-provider' is missing endpoint",
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := c.config.Validate()
			if c.wantErr != "" {
				assert.EqualError(t, got, c.wantErr)
				return
			}
			assert.NoError(t, got)
		})
	}
}

// TestConfig_IsDev checks which environment names IsDev treats as a
// development environment. Only "dev" and "development" should qualify.
func TestConfig_IsDev(t *testing.T) {
	cases := []struct {
		env     string
		wantDev bool
	}{
		{env: "dev", wantDev: true},
		{env: "development", wantDev: true},
		{env: "prod", wantDev: false},
		{env: "production", wantDev: false},
		{env: "anything else", wantDev: false},
	}

	for _, c := range cases {
		assert.Equal(t, c.wantDev, IsDev(c.env), "IsDev(%q)", c.env)
	}
}

// Test_mysqlConnectionString checks the TCP-based MySQL DSN built from a
// host/port configuration. The expected query string pins charset, time
// parsing, the local time zone, compression and the ANSI_QUOTES sql mode.
func Test_mysqlConnectionString(t *testing.T) {
	cfg := &dbConfig{
		Dialect:  "mysql",
		Host:     "test_host",
		Port:     9999,
		Name:     "test_name",
		User:     "test_user",
		Password: "test_password",
		Charset:  "utf8mb4",
		Compress: true,
		MaxConn:  10,
	}

	const dsnFormat = "%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=%s&compress=true&sql_mode=ANSI_QUOTES"
	want := fmt.Sprintf(dsnFormat, cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, "Local")
	got := mysqlConnectionString(cfg)

	assert.Equal(t, want, got)
}

// Test_mysqlConnectionStringSocket checks that a configured Unix socket yields
// a unix(...) MySQL DSN in place of the tcp(host:port) form. The query
// parameters are the same as in the TCP case.
func Test_mysqlConnectionStringSocket(t *testing.T) {
	cfg := &dbConfig{
		Dialect:  "mysql",
		Socket:   "/var/run/mysql.sock",
		Port:     9999,
		Name:     "test_name",
		User:     "test_user",
		Password: "test_password",
		Charset:  "utf8mb4",
		Compress: true,
		MaxConn:  10,
	}

	const dsnFormat = "%s:%s@unix(%s)/%s?charset=utf8mb4&parseTime=true&loc=%s&compress=true&sql_mode=ANSI_QUOTES"
	want := fmt.Sprintf(dsnFormat, cfg.User, cfg.Password, cfg.Socket, cfg.Name, "Local")
	got := mysqlConnectionString(cfg)

	assert.Equal(t, want, got)
}

// Test_postgresConnectionString checks the key=value style Postgres DSN. The
// expected string ends with sslmode=disable.
func Test_postgresConnectionString(t *testing.T) {
	cfg := &dbConfig{
		Dialect:  "postgres",
		Host:     "test_host",
		Port:     9999,
		Name:     "test_name",
		User:     "test_user",
		Password: "test_password",
		MaxConn:  10,
	}

	const dsnFormat = "host=%s port=%d user=%s dbname=%s password=%s sslmode=disable"
	want := fmt.Sprintf(dsnFormat, cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.Password)
	got := postgresConnectionString(cfg)

	assert.Equal(t, want, got)
}

// Test_sqliteConnectionString checks two properties of the SQLite DSN:
// it starts with the database name, and it enables WAL journal mode
// (matched case-insensitively).
func Test_sqliteConnectionString(t *testing.T) {
	cfg := &dbConfig{
		Dialect: "sqlite3",
		Name:    "test_name",
	}

	dsn := sqliteConnectionString(cfg)

	assert.True(t, strings.HasPrefix(dsn, cfg.Name))
	assert.Contains(t, strings.ToLower(dsn), "journal_mode=wal")
}
